'''


exec(''.join(open("/home/dalton_m/ppp/pyfiles/ANpppv11.py", encoding="utf8").readlines()[:]))
nohup python3 /home/dalton_m/ppp/pyfiles/ANpppv11.py  | tee &
'''


import os

import pandas as pd

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import sys

sys.path.append('/home/dalton_m/payload')
from basicfunctions import *


# Root directory that holds every PPP result folder.
resultsloc1 = "/dataERS/eract/daltonm/results/ppp/"

# Base name shared by this script's log file and the figures it writes.
filename = 'ANpppv11'

from datetime import date
# Outputs of each run go into a sub-folder named after today's date (YYYYMMDD).
datestr = date.today().strftime("%Y%m%d")
resultsloc = resultsloc1 + datestr + '/'
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if
# another process created the folder between the check and the call.
os.makedirs(resultsloc, exist_ok=True)


# Log to <resultsloc>/<filename>.txt.  The threshold is ERROR, so of the four
# messages below only the logging.exception() call (which logs at ERROR level)
# passes the filter.
logging.basicConfig(filename=resultsloc + filename + '.txt', level=logging.ERROR,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logging.debug('This message should go to the log file')
logging.info('So should this')
logging.warning('And this, too')
logging.exception('And this, too')
# Route warnings.warn() output through the logging system as well.
logging.captureWarnings(True)

# Make modules stored under <resultsloc1>/pyfiles/ importable.
sys.path.append(resultsloc1 + 'pyfiles/')

# Sampling presets keyed by id; each entry is
# [file-name suffix, sample size, seed number].
sampledict = {
    1: ['big', 3e6, 1],
    2: ['medium', 1e6, 2],
    3: ['small', 1e5, 3],
}

# Preset 3 is the active one; unpack it into the three names used below.
tlist = sampledict[3]
fsuff, sampsize, sampseed = tlist

q = 8
moncount = 3 * q + 1


'''
bringing this in to get PPP info
'''
# Build the establishment-level PPP table `df` and two summary vectors
# (`dft2b`, `dft2`) holding summed PPP employment, PPP amount and PPP wages.
# NOTE(review): `cudf`, `dataloc`, `opennew` and `np` are not defined in the
# visible part of this file - presumably they come from
# `from basicfunctions import *`; confirm.
# NOTE(review): the `np.NaN` alias used below is not available in NumPy >= 2.0
# (`np.nan` is) - confirm the installed NumPy version.
# NOTE(review): the boolean masks below are converted with `.to_pandas()` in
# some places and `.to_array()` in others - confirm both are supported by the
# installed cuDF version.

# Name of the wage column that is summed for PPP establishments further down.
kwvar = 'monthly_wage_avg_19'
df = cudf.read_csv(dataloc + 'pppfiles/' + 'ANppp_ldbv1' + '_step1.csv')
# Keep rows with a non-null aaemp_19 and one row per ldb_num.
kcond = (df['aaemp_19'].notnull())
df = df[kcond].drop_duplicates(subset = 'ldb_num')
#######get ppp info

mc = ['ldb_num']
###ldb specific info
# Loan amount / approval date keyed by ldb_num, left-joined onto df.
dfp = opennew('pppfiles/pppagg2020', ['LoanAmount', 'DateApproved',  'ldb_num']).drop_duplicates(mc)
df = df.merge(dfp, on = mc, how = 'left')

mc = ['ein']
# Same information keyed by ein (the *_ein columns), left-joined onto df.
dfp = opennew('pppfiles/pppagg2020', ['LoanAmount_ein', 'DateApproved_ein',  'ein']).drop_duplicates(mc)
df = df.merge(dfp, on = mc, how = 'left')
# Drop the reference to the lookup table once it has been merged.
dfp = None


# dummy for ppp receipt at the ldb_num level (positive LoanAmount)
cond = (df['LoanAmount'] > 0).to_pandas()
df['Dppp'] = np.where(cond, 1, 0)
# dummy for ppp receipt at EIN level
cond = ((df['LoanAmount'] > 0) | (df['LoanAmount_ein'] > 0)).to_array()
df['Dppp_ein'] = np.where(cond, 1, 0)


####if EIN has less than $500 per employee, then treat it as missing
df['amt_peremp_ein'] = df['LoanAmount_ein'] / df['ein_aaemp_19']
cond = (df['amt_peremp_ein'] < 500).to_pandas()
# Both the EIN-level amount and its approval date are blanked for those rows.
df['LoanAmount_ein'] = np.where(cond, np.NaN, df['LoanAmount_ein'].to_pandas())
df['DateApproved_ein'] = np.where(cond, np.NaN, df['DateApproved_ein'].to_pandas())

###get amount and date approved
# The EIN-level value is used when present, otherwise the ldb_num-level one;
# a missing amount becomes 0 and unparseable dates are coerced to NaT.
df['approval_date'] = pd.to_datetime(df['DateApproved_ein'].to_pandas().fillna(df['DateApproved'].to_pandas()), errors='coerce')
df['amount_final'] = df['LoanAmount_ein'].fillna(df['LoanAmount']).fillna(0)

# dummy for a positive final PPP amount
cond = (df['amount_final'] > 0).to_pandas()
df['Dppp_final'] = np.where(cond, 1, 0)

#####


# Ratio aaemp_19 / ein_aaemp_19, with inf and NaN replaced by 1 and values
# above 1 capped at 1.
df['ppp_wgt'] = (df['aaemp_19'].to_pandas() / df['ein_aaemp_19'].to_pandas()).replace(np.inf,np.NaN).replace(np.NaN, 1)
cond = (df['ppp_wgt'] > 1).to_array()
df['ppp_wgt'] = np.where(cond, 1, df['ppp_wgt'].to_pandas())
#amount of EIN PPP money that is establishment-specific
df['pppamt_ein'] = df['ppp_wgt'] * df['amount_final']
#if amount is LDB amount, then it is already establishment specific
cond = (df['amount_final'] != df['LoanAmount_ein']).to_array()
df['pppamt_calc'] = np.where(cond, df['amount_final'].to_pandas(), np.NaN)
#if amount is EIN PPP amount, use proportioned EIN PPP money calculated above
cond = (df['amount_final'] == df['LoanAmount_ein']).to_array()
df['pppamt_calc'] = np.where(cond, df['pppamt_ein'].to_pandas(), df['pppamt_calc'].to_pandas())
#if PPP estab, get employment
cond = (df['pppamt_calc'] > 0).to_array()
df['emp_ppp'] = np.where(cond, df['aaemp_19'].to_pandas(), 0)
# Same mask as above, recomputed; wages (column `kwvar`) for PPP
# establishments, 0 otherwise.
cond = (df['pppamt_calc'] > 0).to_array()
df['wage_ppp'] = np.where(cond, df[kwvar].to_pandas(), 0)
df['wage_ppp'] = df['wage_ppp'].replace(np.inf, np.NaN)

#all
# Column sums over every row of df.
dft2b = df[['emp_ppp', 'pppamt_calc', 'wage_ppp']].sum()
#keeping only employment greater than 1
# Column sums restricted to rows with emp_max_19 > 1.
kmcond = (df['emp_max_19'] > 1)
dft2 = df[kmcond][['emp_ppp', 'pppamt_calc', 'wage_ppp']].sum()

###get eligibility


# Plotting stack.  The 'Agg' backend is selected before pyplot is imported so
# that figures are rendered to files without a display.
import imageio
import matplotlib as mpl
import matplotlib.ticker as mtick
mpl.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
import seaborn as sns
import matplotlib.font_manager as fm
# Use the cmr10.ttf font shipped in matplotlib's data directory as the serif
# font, together with the 'cm' math-text font set.
mpl.rcParams['font.family']='serif'
cmfont = fm.FontProperties(fname=mpl.get_data_path() + '/fonts/ttf/cmr10.ttf')
mpl.rcParams['font.serif']=cmfont.get_name()
mpl.rcParams['mathtext.fontset']='cm'
mpl.rcParams['axes.unicode_minus']=False
# Apply seaborn's 'darkgrid' style, passing the same font settings as overrides.
# NOTE(review): the key 'font_serif' below is spelled differently from the
# 'font.serif' rcParam set above - confirm whether the underscore is intended
# and whether seaborn honours this key at all.
sns.set_style('darkgrid', {'font.family' : 'serif',
                                   'font_serif':cmfont.get_name(),
                                   'mathtext.fontset' :'cm',
                                   'axes.unicode_minus' : False })


######### main results
def mkgraph(vals):
    """Plot one event-study ATT series with its confidence band and save it.

    The figure is written to ``resultsloc + filename + '_' + <suffix>`` as
    both ``.pdf`` and ``.png``.

    Parameters
    ----------
    vals : sequence
        Graph specification, read positionally:

        * ``vals[0]`` - how many further ``'Dynamic Effects:'`` lines to skip
          after the first one, i.e. the 0-based index of the table to use;
        * ``vals[1]`` - not used by this function;
        * ``vals[2]`` - text appended to the plot title;
        * ``vals[3]`` - y-axis label;
        * ``vals[4]`` - file-name suffix; a suffix containing ``'closed'``
          multiplies the values by 100, swaps the sign of the two reference
          lines and moves the legend down;
        * ``vals[5]`` - y-axis limits;
        * ``vals[6]`` - ``[first, last]`` line offsets, relative to the
          selected header line, of the table rows to read (``last`` is
          exclusive);
        * ``vals[7]`` - the results file as a list of lines.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``event_time`` (as parsed, i.e. strings) with the columns
        ``ATT`` and ``std_err_simulat`` (= ``|ATT - 95_ci_lb|``).

    Raises
    ------
    ValueError
        If the requested ``'Dynamic Effects:'`` header is not present.
    """
    numresult = vals[0]
    ptitle = vals[2]
    ylab = vals[3]
    fsuff = vals[4]
    ylim = vals[5]
    firstval, lastval = vals[6][0], vals[6][1]
    txt = vals[7]
    is_closed = 'closed' in fsuff

    # Locate the requested table header.  Each further search starts 5 lines
    # after the previous hit; list.index(value, start) does this without
    # copying the remainder of the list on every step.
    header = 'Dynamic Effects:'
    strt1 = txt.index(header)
    for _ in range(numresult):
        strt1 = txt.index(header, strt1 + 5)

    cols = ['event_time', 'ATT', 'std_err', '95_ci_lb', '95_ci_ub']
    # Drop the last two characters of every row (the trailing significance
    # star) and split the rest on runs of spaces.
    rows = [[tok for tok in line[:-2].strip().split(' ') if tok]
            for line in txt[strt1 + firstval:strt1 + lastval]]
    df = pd.DataFrame(rows, columns=cols)
    for c in ['ATT', 'std_err', '95_ci_lb']:
        df[c] = pd.to_numeric(df[c])

    # Half-width of the confidence band, used as the error-bar length.
    df['std_err_simulat'] = abs(df['ATT'] - df['95_ci_lb'])
    dft = df.set_index('event_time')[['ATT', 'std_err_simulat']]

    # Keep only event times inside the window implied by the line offsets.
    fval = firstval - 7
    lval = lastval - 8
    event_times = pd.to_numeric(dft.index)
    dft = dft.loc[(event_times >= fval) & (event_times <= lval)]
    if is_closed:
        dft = 100 * dft

    f, ax = plt.subplots(nrows=1, ncols=1, figsize=(9.5, 6), constrained_layout=True)
    try:
        ax.errorbar(x=dft.index, y=dft['ATT'], yerr=dft['std_err_simulat'], fmt='o', color='black',
                    ecolor='darkgray', elinewidth=3, capsize=0)
        ax.plot(dft['ATT'], color='darkblue')

        # Two horizontal reference lines at +/-0.02, colours swapped for 'closed'.
        tval = -1 if is_closed else 1
        ax.axhline(y=tval * .02, color='lightgreen', linestyle='-')
        ax.axhline(y=tval * -.02, color='lightcoral', linestyle='-')

        ax.set_ylim(ylim)
        ax.set_title("Effects of Receiving a PPP Loan on " + ptitle, fontsize=15)
        ax.set_xlabel("Months Until/Since PPP Approval", fontsize=10)
        ax.set_ylabel(ylab + "\n[Average Treatment on Treated]", fontsize=10)
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

        legend_elements = [
            Line2D([0], [0], lw=3, color='darkgray', label='95% Simultaneous Confidence Band', markersize=8),
            Line2D([0], [0], lw=3, color='darkblue', label='ATT of PPP Loan', markersize=8),
        ]
        # The legend sits lower on 'closed' graphs.
        legend_y = 0.2 if is_closed else 0.85
        f.legend(handles=legend_elements, fontsize=10, shadow=True, fancybox=True,
                 bbox_to_anchor=(.65, legend_y), loc="upper left", borderaxespad=0)

        f.savefig(resultsloc + filename + '_' + fsuff + '.pdf', format='pdf', dpi=1200)
        f.savefig(resultsloc + filename + '_' + fsuff + '.png', format='png', dpi=600)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(f)
    return dft

def mkgraph_pre(vals):
    """Plot only the early pre-period part of an event-study series and save it.

    The full series is first drawn fully transparent, then every point with
    ``event_time >= -1`` is blanked and the remaining points are drawn
    visibly.  The figure is written to
    ``resultsloc + filename + '_' + <suffix> + '_pre'`` as ``.pdf`` and
    ``.png``.

    Parameters
    ----------
    vals : sequence
        Graph specification, read positionally:

        * ``vals[0]`` - how many further ``'Dynamic Effects:'`` lines to skip
          after the first one, i.e. the 0-based index of the table to use;
        * ``vals[1]`` - not used by this function;
        * ``vals[2]`` - text appended to the plot title;
        * ``vals[3]`` - y-axis label;
        * ``vals[4]`` - file-name suffix; a suffix containing ``'closed'``
          multiplies the values by 100 and swaps the sign of the two
          reference lines;
        * ``vals[5]`` - y-axis limits;
        * ``vals[6]`` - ``[first, last]`` line offsets, relative to the
          selected header line, of the table rows to read (``last`` is
          exclusive);
        * ``vals[7]`` - the results file as a list of lines.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``event_time`` (strings) with the columns ``ATT`` and
        ``std_err_simulat``; rows with ``event_time >= -1`` hold NaN.

    Raises
    ------
    ValueError
        If the requested ``'Dynamic Effects:'`` header is not present.
    """
    numresult = vals[0]
    ptitle = vals[2]
    ylab = vals[3]
    fsuff = vals[4]
    ylim = vals[5]
    firstval, lastval = vals[6][0], vals[6][1]
    txt = vals[7]
    is_closed = 'closed' in fsuff
    plot_cols = ['ATT', 'std_err_simulat']

    # Locate the requested table header.  Each further search starts 5 lines
    # after the previous hit; list.index(value, start) does this without
    # copying the remainder of the list on every step.
    header = 'Dynamic Effects:'
    strt1 = txt.index(header)
    for _ in range(numresult):
        strt1 = txt.index(header, strt1 + 5)

    cols = ['event_time', 'ATT', 'std_err', '95_ci_lb', '95_ci_ub']
    # Drop the last two characters of every row (the trailing significance
    # star) and split the rest on runs of spaces.
    rows = [[tok for tok in line[:-2].strip().split(' ') if tok]
            for line in txt[strt1 + firstval:strt1 + lastval]]
    df = pd.DataFrame(rows, columns=cols)
    for c in ['ATT', 'std_err', '95_ci_lb']:
        df[c] = pd.to_numeric(df[c])

    # Half-width of the confidence band, used as the error-bar length.
    df['std_err_simulat'] = abs(df['ATT'] - df['95_ci_lb'])
    dft = df.set_index('event_time')[plot_cols]

    # Keep only event times inside the window implied by the line offsets.
    fval = firstval - 7
    lval = lastval - 8
    event_times = pd.to_numeric(dft.index)
    dft = dft.loc[(event_times >= fval) & (event_times <= lval)]
    if is_closed:
        dft = 100 * dft

    f, ax = plt.subplots(nrows=1, ncols=1, figsize=(9.5, 6), constrained_layout=True)
    try:
        # Invisible pass (alpha=0) over the whole series so the x axis covers
        # the same event times as the full graph.
        ax.errorbar(x=dft.index, y=dft['ATT'], yerr=dft['std_err_simulat'], fmt='o', color='black',
                    ecolor='darkgray', elinewidth=0, capsize=0, alpha=0)
        ax.plot(dft['ATT'], color='darkblue', alpha=0)

        # Blank every estimate from event time -1 onwards.  np.nan is used
        # because the np.NaN alias no longer exists in NumPy 2.0.
        cond = (pd.to_numeric(df['event_time']) >= -1)
        for c in plot_cols:
            df[c] = np.where(cond, np.nan, df[c])
        # NOTE: as before, this second frame is rebuilt from all parsed rows
        # and is not re-filtered to the [fval, lval] window.
        dft = df.set_index('event_time')[plot_cols]
        if is_closed:
            dft = 100 * dft

        # Visible pass: only the remaining pre-period points are drawn.
        ax.errorbar(x=dft.index, y=dft['ATT'], yerr=dft['std_err_simulat'], fmt='o', color='black',
                    ecolor='darkgray', elinewidth=3, capsize=0)
        ax.plot(dft['ATT'], color='darkblue')

        # Two horizontal reference lines at +/-0.02, colours swapped for 'closed'.
        tval = -1 if is_closed else 1
        ax.axhline(y=tval * .02, color='lightgreen', linestyle='-')
        ax.axhline(y=tval * -.02, color='lightcoral', linestyle='-')

        ax.set_ylim(ylim)
        ax.set_title("Effects of Receiving a PPP Loan on " + ptitle, fontsize=15)
        ax.set_xlabel("Months Until/Since PPP Approval", fontsize=10)
        ax.set_ylabel(ylab + "\n[Average Treatment on Treated]", fontsize=10)
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

        legend_elements = [
            Line2D([0], [0], lw=3, color='darkgray', label='95% Simultaneous Confidence Band', markersize=8),
            Line2D([0], [0], lw=3, color='darkblue', label='ATT of PPP Loan', markersize=8),
        ]
        f.legend(handles=legend_elements, fontsize=10, shadow=True, fancybox=True,
                 bbox_to_anchor=(.65, 0.85), loc="upper left", borderaxespad=0)

        f.savefig(resultsloc + filename + '_' + fsuff + '_pre.pdf', format='pdf', dpi=1200)
        f.savefig(resultsloc + filename + '_' + fsuff + '_pre.png', format='png', dpi=600)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(f)
    return dft


def multigraphplot(graphdict, numcols):
    """Draw several event-study graphs on one grid of panels and save it.

    Panels are filled row by row, ``numcols`` per row; unused panels in the
    last row are removed.  The figure is written to
    ``resultsloc + filename + '_combined_' + <suffix>`` (``.pdf`` and
    ``.png``), where ``<suffix>`` is the file suffix of the LAST entry of
    ``graphdict``.

    Parameters
    ----------
    graphdict : dict
        Maps a key to a graph specification read positionally:
        ``[table index, <unused>, title text, y label, file suffix, y limits,
        [first, last] line offsets, list of result-file lines]``.  A suffix
        containing ``'closed'`` multiplies the values by 100, swaps the sign
        of the two reference lines and moves the legend down.
    numcols : int
        Number of panel columns.

    Returns
    -------
    dict
        ``{key: DataFrame}`` with one frame per graph, indexed by
        ``event_time`` with the columns ``ATT`` and ``std_err_simulat``.

    Raises
    ------
    ValueError
        If ``graphdict`` is empty or a requested ``'Dynamic Effects:'``
        header is not present.
    """
    if not graphdict:
        raise ValueError("graphdict must contain at least one graph specification")

    numrows = math.ceil(len(graphdict) / numcols)
    # squeeze=False keeps `axes` two-dimensional for every grid shape.  The
    # previous code indexed a squeezed array with the row number only, which
    # drew every panel of a one-row grid on the first axes, failed for a 1x1
    # grid, and broke the 2-D indexing used when deleting leftover panels.
    f, axes = plt.subplots(nrows=numrows, ncols=numcols, figsize=(9.5 * numcols, 6 * numrows),
                           constrained_layout=True, squeeze=False)
    try:
        header = 'Dynamic Effects:'
        cols = ['event_time', 'ATT', 'std_err', '95_ci_lb', '95_ci_ub']
        dflist = {}
        graphnum = 0
        for k, vals in graphdict.items():
            numresult = vals[0]
            ptitle = vals[2]
            ylab = vals[3]
            fsuff = vals[4]
            ylim = vals[5]
            firstval, lastval = vals[6][0], vals[6][1]
            txt = vals[7]
            is_closed = 'closed' in fsuff

            # Locate the requested table header; each further search starts
            # 5 lines after the previous hit.
            strt1 = txt.index(header)
            for _ in range(numresult):
                strt1 = txt.index(header, strt1 + 5)

            # Drop the last two characters of every row (the trailing
            # significance star) and split the rest on runs of spaces.
            rows = [[tok for tok in line[:-2].strip().split(' ') if tok]
                    for line in txt[strt1 + firstval:strt1 + lastval]]
            df = pd.DataFrame(rows, columns=cols)
            for c in ['ATT', 'std_err', '95_ci_lb']:
                df[c] = pd.to_numeric(df[c])

            # Half-width of the confidence band, used as the error-bar length.
            df['std_err_simulat'] = abs(df['ATT'] - df['95_ci_lb'])
            dft = df.set_index('event_time')[['ATT', 'std_err_simulat']]
            if is_closed:
                dft = 100 * dft
            dflist[k] = dft

            # Row-major position of this graph in the grid.
            row, col = divmod(graphnum, numcols)
            graphnum += 1
            ax = axes[row, col]

            ax.errorbar(x=dft.index, y=dft['ATT'], yerr=dft['std_err_simulat'], fmt='o', color='black',
                        ecolor='darkgray', elinewidth=3, capsize=0)
            ax.plot(dft['ATT'], color='darkblue')

            # Two horizontal reference lines at +/-0.02, colours swapped for 'closed'.
            tval = -1 if is_closed else 1
            ax.axhline(y=tval * .02, color='lightgreen', linestyle='-')
            ax.axhline(y=tval * -.02, color='lightcoral', linestyle='-')

            ax.set_ylim(ylim)
            ax.set_title("Effects of Receiving a PPP Loan on " + ptitle, fontsize=15)
            ax.set_xlabel("Months Until/Since PPP Approval", fontsize=10)
            ax.set_ylabel(ylab + "\n[Average Treatment on Treated]", fontsize=10)
            ax.yaxis.set_major_formatter(mtick.PercentFormatter())

            legend_elements = [
                Line2D([0], [0], lw=3, color='darkgray', label='95% Simultaneous Confidence Band', markersize=8),
                Line2D([0], [0], lw=3, color='darkblue', label='ATT of PPP Loan', markersize=8),
            ]
            # The legend sits lower on 'closed' graphs.
            legend_y = 0.2 if is_closed else 0.85
            ax.legend(handles=legend_elements, fontsize=10, shadow=True, fancybox=True,
                      bbox_to_anchor=(.65, legend_y), loc="upper left", borderaxespad=0)

        # Remove the empty panels at the end of the last row, if any.
        leftover = graphnum % numcols
        if leftover:
            for idx in range(numcols - leftover):
                f.delaxes(axes[numrows - 1, numcols - 1 - idx])

        f.savefig(resultsloc + filename + '_combined_' + fsuff + '.pdf', format='pdf', dpi=1200)
        f.savefig(resultsloc + filename + '_combined_' + fsuff + '.png', format='png', dpi=600)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(f)
    return dflist


# Y-axis labels, file-name suffixes and y-axis limits for the three outcomes;
# the numeric suffix ties them together (1: employment, 2: closing, 3: wages).
ylab1, ylab2, ylab3 = (
    'Effect on Employment %\n Relative to Pre-Pandemic Employment Baseline',
    'Effect on Probability of Closing',
    'Effect on Wages %\n Relative to  Pre-Pandemic Wage Baseline',
)
filesuff1, filesuff2, filesuff3 = 'emp', 'closed', 'wage'
ylim1, ylim2, ylim3 = (-3.5, 12), (-8.5, 2), (-6, 17)

'''
****************************************************************
****************************************************************
main results + pre graphs
****************************************************************
****************************************************************
'''

# Line offsets, relative to a 'Dynamic Effects:' header line, of the table
# rows that mkgraph / mkgraph_pre read: rows strt+2 .. strt+22, which those
# functions map to event times -5 .. 15.
firstval = 2
lastval = 23

# All three graphs read the same results file, so load it once and close the
# handle (it was previously opened three times and never closed).
with open(mostrecentfile('ANldb_pppv7.txt', resultsloc1)) as fh:
    main_lines = fh.read().split('\n')

graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim,
    #          [first, last] row offsets, lines of the results file]
    # -1 : [6, 'Overall Effects', '\nEmployment Relative to 2019 Employment', ylab1,  filesuff1 + '_nocontrols', ylim1 ],
    0: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1, filesuff1, ylim1,
        [firstval, lastval], main_lines],
    # -2 : [7, 'Overall Effects', '\nProbability of Establishment Closure', ylab2,  filesuff2 + '_nocontrols', ylim2 ],
    1: [2, 'Overall Effects', '\nProbability of Establishment Closure', ylab2, filesuff2, ylim2,
        [firstval, lastval], main_lines],
    5: [1, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline', ylab3, filesuff3, ylim3,
        [firstval, lastval], main_lines],
}

# For every outcome write the pre-period-only figure and the full figure;
# only the data behind the full figure is collected in dflist.
dflist = {}
for k, vals in graphdict.items():
    dft = {k: mkgraph_pre(vals)}
    dflist[k] = mkgraph(vals)




'''
**************************************************************************
**************************************************************************
baseline1
**************************************************************************
'''
def _read_result_lines(name):
    """Return the lines of the results file that `mostrecentfile` picks for *name*.

    The file is opened in a `with` block so the handle is closed right after
    reading (the inline open(...).read() calls used before never closed it).
    """
    with open(mostrecentfile(name, resultsloc1, -1)) as fh:
        return fh.read().split('\n')


graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim,
    #          [first, last] row offsets, lines of the results file]
    ### without controls
    0: [3, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1,
        filesuff1 + '_nocontrol', ylim1, [firstval, lastval], _read_result_lines('ANldb_pppv7.txt')],

    ###with controls
    1: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1,
        filesuff1 + '', ylim1, [firstval, lastval], _read_result_lines('ANldb_pppv7.txt')],

    ###removes multi-programs
    3: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1,
        filesuff1 + '_clean', ylim1, [firstval, lastval], _read_result_lines('ANldb_pppv7clean.txt')],

    ###'_noclose' results file
    5: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1,
        filesuff1 + '_noclose', ylim1, [firstval, lastval], _read_result_lines('ANldb_pppv7_noclose.txt')],

    ###employment cutoff
    4: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1,
        filesuff1 + '_cutoff', ylim1, [firstval, lastval], _read_result_lines('ANldb_pppv7_cutoff.txt')],
}


# One figure per specification; the plotted data is kept, in the dict's
# insertion order, for the LaTeX table printed next.
dflist = {}
for k, vals in graphdict.items():
    dflist[k] = mkgraph(vals)



###latex table code
# Put every graph's (ATT, std_err_simulat) pair side by side: after the
# renaming, even column labels hold ATT and the following odd label holds the
# matching half-width of the confidence band.
dft = pd.concat(list(dflist.values()), axis=1)
ncols = len(dft.columns)
dft.columns = list(range(ncols))
dft.reset_index(inplace=True)
# Raw strings: '\m' and '\%' are not valid Python escape sequences and trigger
# a warning in ordinary string literals; the printed text is unchanged.
txt1 = r"\multirow{2}*{"

# Two printed lines per event time: the ATT estimates, then the intervals.
for i in range(len(dft)):
    row = dft.iloc[i]
    rown = row['event_time']

    # ATT line; missing estimates are shown as '-'.
    att_cells = [str(round(row[j], 2)).replace('nan', '-') for j in range(0, ncols, 2)]
    txt = txt1 + str(rown) + "}&ATT&" + "&".join(att_cells)
    print(txt + r"\\")

    # Interval line: [ATT - half-width, ATT + half-width] for each graph.
    ci_cells = []
    for j in range(1, ncols, 2):
        tvallb = str(round(row[j - 1] - row[j], 1)).replace('nan', '-')
        tvalub = str(round(row[j - 1] + row[j], 1)).replace('nan', '-')
        ci_cells.append("[" + tvallb + ', ' + tvalub + "]")
    txt = r"&[95\% C.I.]&" + "&".join(ci_cells)
    print(txt + r"\\")





'''
########################################################################
########################################################################

2019
########################################################################
########################################################################
'''

# Read the 2019 results file inside a `with` block so the handle is closed
# (the inline open(...).read() used before never closed it).
with open(mostrecentfile('ANldb_pppv7_2019.txt', resultsloc1, -1)) as fh:
    lines_2019 = fh.read().split('\n')

graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim,
    #          [first, last] row offsets, lines of the results file]
    # The shorter last offset (16 instead of lastval) reads fewer table rows.
    4: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1, filesuff1 + '_2019',
        ylim1, [firstval, 16], lines_2019],
}

# Write the pre-period-only figure and the full figure; only the data behind
# the full figure is collected in dflist.
dflist = {}
for k, vals in graphdict.items():
    dft = {k: mkgraph_pre(vals)}
    dflist[k] = mkgraph(vals)



'''
########################################################################
########################################################################
closures
########################################################################
########################################################################
'''

# Read the results file inside a `with` block so the handle is closed
# (the inline open(...).read() used before never closed it).
with open(mostrecentfile('ANldb_pppv7.txt', resultsloc1, -1)) as fh:
    lines_closures = fh.read().split('\n')

graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim,
    #          [first, last] row offsets, lines of the results file]
    # Title fixed: this entry plots table 2 with the closure y label, suffix
    # and limits, but carried the employment title copied from the other
    # sections.  It now uses the same closure title as the 'closed' entry of
    # the main results, whose output files this figure overwrites.
    4: [2, 'Overall Effects', '\nProbability of Establishment Closure', ylab2, filesuff2 + '',
        ylim2, [firstval, lastval], lines_closures],
}

# Write the pre-period-only figure and the full figure; only the data behind
# the full figure is collected in dflist.
dflist = {}
for k, vals in graphdict.items():
    dft = {k: mkgraph_pre(vals)}
    dflist[k] = mkgraph(vals)


'''
CES1
'''
# ylab1a = "Effect on Employment %\n Relative to January 2020 Employment"
# ylab4 = 'Effect on Hours %\n Relative to January 2020 Hours'
# filesuff4 = 'hours'
# ylim4 = (-7,9)
# ylab5 = 'Effect on Hours per Employee %\n Relative to January 2020 Hours per Employee'
# filesuff5 = 'hpe'
# ylim5 = (-5,5)
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#
# 0: [1, 'Overall Effects', '\nEmployment Relative to January 2020\nUsing Current Employment Statistics survey', ylab4,  filesuff1 + 'ces', ylim1 ],
#     1: [4, 'Overall Effects', '\nHours Worked Relative to January 2020', ylab4,  filesuff4, ylim4 ],
#     #1: [6, '"Cutoff Sample"', '\nHours Worked Relative to January 2020', ylab4, filesuff4 + '_cutoff', ylim4],
#     2: [6, 'Overall Effects', '\nHours per Employee Relative to January 2020', ylab5, filesuff5, ylim5],
#     #3: [10, '"Cutoff Sample"', '\nHours per Employee Relative to January 2020', ylab5, filesuff5 + '_cutoff', ylim5],
#
# }
# fname = mostrecentfile('ANces_pppv14.txt', resultsloc1, -1)
# txt = open(fname).read().split('\n')
# dflist = {}
# for k,vals in graphdict.items():
#     dflist.update({k : mkgraph(txt, vals)})
#
#
#
# ###latex table code
# dft = pd.concat(list(dflist.values()), axis=1)
# ncols = len(dft.columns)
# dft.columns = [i for i in range(ncols)]
# dft.reset_index(inplace=True)
# txt1 = "\multirow{2}*{"
#
# for i in range(len(dft)):
#     txt = txt1
#     rown = dft.iloc[i]['event_time']
#     txt = txt + str(rown) + "}&ATT&"
#     for j in range(0,ncols,2):
#         tval = str(round(dft.iloc[i][j],2))
#         txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")
#     txt = "&[95\% C.I.]&"
#     for j in range(0,ncols):
#         if j%2==1:
#             tvallb = str(round(dft.iloc[i][j-1] - dft.iloc[i][j],1))
#             tvalub = str(round(dft.iloc[i][j - 1] + dft.iloc[i][j],1))
#             tval = "[" + tvallb + ', '+tvalub + "]"
#             txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")

'''
cutoff results
 '''
# ylim1c = (-5.5,11.1)
#
# fname = mostrecentfile('ANldb_pppv7_cutoff.txt', resultsloc1)
# txt = open(fname).read().split('\n')
#
# ylim3c = (-7,15)
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#
#     1: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline\nRestricting Sample to Employers Near SBA Cutoff', ylab1,  filesuff1 + 'cutoff', ylim1c, [firstval, 16], open(mostrecentfile('ANldb_pppv7.txt', resultsloc1, -1)).read().split('\n') ],
# # 0 : [8, 'Overall Effects', '\nProbability of Establishment Closure', ylab2,  filesuff2 + 'cutoff', ylim2, firstval, txt ],
# # 2 : [6, 'Overall Effects', '\nWages Relative to Pre-Pandemic Employment Baseline\nRestricting Sample to Employers Near SBA Cutoff', ylab3,  filesuff3 + 'cutoff', ylim3c, firstval, txt ],
# }
#
# dflist = {}
# for k,vals in graphdict.items():
#     dflist.update({k : mkgraph( vals)})





'''
2019 results
'''
# firstval = 4
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#
#     1: [1, 'Overall Effects', '\nEmployment Relative to Pre-2018 Employment Baseline\nPre-Pandemic Placebo Test', ylab1,  filesuff1 + '_2019', ylim1,
#         firstval,
#         ],
#     0: [0, 'Overall Effects', '\nWage Relative to Pre-2018 Wage Baseline\nPre-Pandemic Placebo Test',
#         ylab3, filesuff3 + '_2019', ylim3],
# # 0 : [1, 'Overall Effects', '\nProbability of Establishment Closure', ylab2,  filesuff2 + 'cutoff', ylim2 ],
# # 2 : [2, 'Overall Effects', '\nWages Relative to 2019 Wages', ylab3,  filesuff3 + 'cutoff', ylim3 ],
# }
# fname = mostrecentfile('ANldb_pppv5_2019.txt', resultsloc1)
# txt = open(fname).read().split('\n')
# dflist = {}
# for k,vals in graphdict.items():
#     dflist.update({k : mkgraph(txt, vals)})



###latex table code
# dft = pd.concat(list(dflist.values()), axis=1)
# ncols = len(dft.columns)
# dft.columns = [i for i in range(ncols)]
# dft.reset_index(inplace=True)
# txt1 = "\multirow{2}*{"
#
# for i in range(len(dft)):
#     txt = txt1
#     rown = dft.iloc[i]['event_time']
#     txt = txt + str(rown) + "}&ATT&"
#     for j in range(0,ncols,2):
#         tval = str(round(dft.iloc[i][j],2))
#         txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")
#     txt = "&[95\% C.I.]&"
#     for j in range(0,ncols):
#         if j%2==1:
#             tvallb = str(round(dft.iloc[i][j-1] - dft.iloc[i][j],1))
#             tvalub = str(round(dft.iloc[i][j - 1] + dft.iloc[i][j],1))
#             tval = "[" + tvallb + ', '+tvalub + "]"
#             txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")


'''
baseline2
'''
# #########change to recent file
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim, filename]
#
# 1 : [2, 'Overall Effects', '\nClosure Status in Month', ylab2,  filesuff2 + '', ylim2, mostrecentfile('ANldb_pppv5.txt', resultsloc1, -1) ],
#
# 2 : [0, 'Overall Effects', '\nMonthly Wages Relative to Pre-Pandemic Wage Baseline', ylab3,  filesuff3 + '', ylim3, mostrecentfile('ANldb_pppv5.txt', resultsloc1, -1) ],
#
# }
#
# # dflist = {}
# # for k,vals in graphdict.items():
# #     if k<=3:
# #         fname = vals[6]
# #         txt = open(fname).read().split('\n')
# #         dflist.update({k : mkgraph(txt, vals)})
# #     else:
# #         fname = vals[6]
# #         txt = open(fname).read().split('\n')
# #         dflist.update({k : mkgraph(txt, vals)})
#
# numcols = 1
# dflist = multigraphplot(graphdict, numcols)
#
# ###latex table code
# dft = pd.concat(list(dflist.values()), axis=1)
# ncols = len(dft.columns)
# dft.columns = [i for i in range(ncols)]
# dft.reset_index(inplace=True)
# txt1 = "\multirow{2}*{"
#
# for i in range(len(dft)):
#     txt = txt1
#     rown = dft.iloc[i]['event_time']
#     txt = txt + str(rown) + "}&ATT&"
#     for j in range(0,ncols,2):
#         tval = str(round(dft.iloc[i][j],2))
#         txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")
#     txt = "&[95\% C.I.]&"
#     for j in range(0,ncols):
#         if j%2==1:
#             tvallb = str(round(dft.iloc[i][j-1] - dft.iloc[i][j],1))
#             tvalub = str(round(dft.iloc[i][j - 1] + dft.iloc[i][j],1))
#             tval = "[" + tvallb + ', '+tvalub + "]"
#             txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")
#

'''
CES2
'''
# ylab3a = "Effect on Wage %\n Relative to January 2020 Wages"
# ylab6 = 'Effect on Pay per Hour %\n Relative to January 2020 Pay per Hour'
# filesuff6 = 'wph'
# ylim6 = (-5,5)
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#
# 0: [0, 'Overall Effects', '\nEmployment Relative to January 2020\nUsing Current Employment Statistics survey', ylab3a,  filesuff3 + 'ces', ylim3 ],
#     2: [5, 'Overall Effects', '\nPay per Hour Relative to January 2020', ylab6, filesuff6, ylim6],
#
# }
# fname = mostrecentfile('ANces_pppv14.txt', resultsloc1, -1)
# txt = open(fname).read().split('\n')
# dflist = {}
# for k,vals in graphdict.items():
#     dflist.update({k : mkgraph(txt, vals)})



###latex table code
# Built from whatever dflist currently holds.  After the renaming, even column
# labels hold ATT and the following odd label holds the matching half-width of
# the confidence band.
dft = pd.concat(list(dflist.values()), axis=1)
ncols = len(dft.columns)
dft.columns = list(range(ncols))
dft.reset_index(inplace=True)
# Raw strings: '\m' and '\%' are not valid Python escape sequences and trigger
# a warning in ordinary string literals; the printed text is unchanged.
txt1 = r"\multirow{2}*{"

# Two printed lines per event time: the ATT estimates, then the intervals.
for i in range(len(dft)):
    row = dft.iloc[i]
    rown = row['event_time']

    # ATT line.
    att_cells = [str(round(row[j], 2)) for j in range(0, ncols, 2)]
    txt = txt1 + str(rown) + "}&ATT&" + "&".join(att_cells)
    print(txt + r"\\")

    # Interval line: [ATT - half-width, ATT + half-width] for each graph.
    ci_cells = []
    for j in range(1, ncols, 2):
        tvallb = str(round(row[j - 1] - row[j], 1))
        tvalub = str(round(row[j - 1] + row[j], 1))
        ci_cells.append("[" + tvallb + ', ' + tvalub + "]")
    txt = r"&[95\% C.I.]&" + "&".join(ci_cells)
    print(txt + r"\\")

'''
ui state
'''
# ylim1a = (-3, 10)
# ylim3a = (-4, 14)
# fname = mostrecentfile('ANldb_pppv5_uistate.txt', resultsloc1, -1)
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#
#     1: [1, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline\nRestricting Sample to Low UI Replacement Rate States', ylab1,  filesuff1 + 'ui_state', ylim1a, fname, 4 ],
#     2: [3, 'Overall Effects',
#         '\nEmployment Relative to Pre-Pandemic Employment Baseline\nRestricting Sample to High UI Replacement Rate States', ylab1,
#         filesuff1 + 'ui_state', ylim1a, fname, 4],
#     3: [0, 'Overall Effects',
#         '\nWages Relative to Pre-Pandemic Wage Baseline\nRestricting Sample to Low UI Replacement Rate States', ylab3,
#         filesuff3 + 'ui_state', ylim3a, fname, 4],
#     4: [2, 'Overall Effects',
#         '\nWages Relative to Pre-Pandemic Wage Baseline\nRestricting Sample to High UI Replacement Rate States', ylab3,
#         filesuff3 + 'ui_state', ylim3a, fname, 4],
#
# }
#
# #txt = open(fname).read().split('\n')
# # dflist = {}
# # for k,vals in graphdict.items():
# #     dflist.update({k : mkgraph(txt, vals)})
#
# numcols = 2
# dflist = multigraphplot(graphdict, numcols)
#


'''
this is for the employee months saved

pppperemp2
pppperemp_presentation
'''
fname = mostrecentfile('ANldb_pppv7.txt', resultsloc1, -1)
# Read the results file inside a context manager so the handle is closed as
# soon as the text has been loaded.
with open(fname) as resultsfile:
    txt = resultsfile.read().split('\n')
graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    0: [0, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline', ylab1, filesuff1, ylim1,
        [firstval, lastval],
        txt],

    5: [1, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline', ylab3, filesuff3, ylim3,
        [firstval, lastval],
        txt],
}

######change most recent file
dflist = {k: mkgraph(vals) for k, vals in graphdict.items()}

# Not read anywhere in this section; kept because they are module-level names.
txt1 = ""
txtlist = []
totemp = 0
totwage = 0

# Prefix of the dft2 entries that belong to each graphdict key.
prefbykey = {0: 'emp', 5: 'wage'}

###change the second number to get last month
###want most recent [15], but may also use 7 to match draft 1
eventrange = [str(i) for i in range(0, 16)]

# The per-outcome frames do not depend on which event time is being printed,
# so build them once instead of once per table row.
framebykey = {}
for k, dfk in dflist.items():
    # .copy() after filtering so the new column is written to an independent
    # frame rather than to a slice of the mkgraph result.
    dfk = dfk.reset_index()
    dfk = dfk[dfk['event_time'].isin(eventrange)].copy()
    #ATT times PPP establishments
    dfk['exact'] = (dfk['ATT'] * dft2[prefbykey[k] + '_ppp']) / 100
    framebykey[k] = dfk

for i in eventrange:
    # rlist collects three cells per outcome: row label, ATT, scaled total.
    rlist = []
    for k, dft in framebykey.items():
        pref = prefbykey[k]
        ###latex table code
        rown = i
        rlist.append("$ATT_{" + str(rown) + "}$ & ")
        # Rows are addressed by position within the filtered frame.
        tval = str(round(dft.iloc[int(i)]['ATT'], 2))
        rlist.append(tval + "&")
        tval = "{:,}".format(int(round(dft.iloc[int(i)]['exact'], 0)))
        rlist.append(tval + "&")
        if rown == '0':
            # Printed once per outcome: the total over all event times, then
            # the PPP amount divided by that total.
            print("{:,}".format(int(round(dft['exact'].sum(), 0))))
            if pref == 'emp':
                print("{:,}".format(int(round(dft2['pppamt_calc'] / dft['exact'].sum(), 0))))
            elif pref == 'wage':
                print("{:,}".format(round(dft2['pppamt_calc'] / dft['exact'].sum(), 2)))
    # Row layout: label, first-outcome ATT, second-outcome ATT,
    # first-outcome total, second-outcome total.
    txt = rlist[0] + rlist[1] + rlist[4] + rlist[2] + rlist[5]
    # Drop the trailing "&" and terminate the LaTeX row.
    print(txt[:-1] + r"\\")
    txtlist.append(txt)




'''
employee months and wage saved

might not need this anymore
'''
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     1: [8, 'Overall Effects', '\nEmployment Relative to 2019 Employment', ylab1,  filesuff1 + 'cutoff', ylim1 ],
#
# 2 : [1, 'Overall Effects', '\nWages Relative to 2019 Wages', ylab3,  filesuff3 + 'cutoff', ylim3 ],
# }
# fname = mostrecentfile('ANldb_pppv5_cutoff.txt', resultsloc1)
# txt = open(fname).read().split('\n')
# dflist = {}
# for k,vals in graphdict.items():
#     if k==1:
#         fname = mostrecentfile('ANldb_pppv5.txt', resultsloc1)
#         txt = open(fname).read().split('\n')
#     elif k==2:
#         fname = mostrecentfile('ANldb_pppv5_wages.txt', resultsloc1)
#         txt = open(fname).read().split('\n')
#     dflist.update({k : mkgraph(txt, vals)})
#
#
# dft = dflist[1].iloc[5:]
# dft['exact'] =  ((dft['ATT'] ) * dft2['emp_ppp']) / 100
# dft.reset_index(inplace=True)
# dftb = dflist[1].iloc[5:]
# dftb['exact_wage'] =  ((dftb['ATT'] ) * dft2['wage_ppp']) / 100
# dftb.reset_index(inplace=True)
# dft = pd.concat([dft,dftb], axis=1)
#
# ###latex table code
# txt1 = ""
#
# #final sum
# txt = r"\$ of PPP Loans per...&"
# dft3 = dft2['pppamt_calc'] / dft[['exact', 'exact_wage']].sum()
# tval = "\$" + "{:,}".format(int(round(dft3['exact'], 0)))
# txt = txt + tval + "&"
# tval = "\$" + str(round(dft3['exact_wage'], 2))
# txt = txt + tval + "&"
#
# print(txt[:-1] + r"\\")


'''
heterogeneity analysis
'''


def _latex_dollars(tval):
    """Format a dollar value for LaTeX: whole dollars above 100, else cents."""
    if tval > 100:
        return r"\$" + "{:,}".format(int(round(tval, 0)))
    return r"\$" + str(round(tval, 2))


def _latex_percent(tval):
    """Format a percentage for LaTeX with an escaped percent sign."""
    if tval > 100:
        return "{:,}".format(int(round(tval, 0))) + r'\%'
    if tval > 10:
        # NOTE(review): int() discards the decimal that round(tval, 1) keeps,
        # so this branch truncates (45.67 -> "45"). Kept as-is so existing
        # tables do not change; confirm whether one decimal was intended.
        return "{:,}".format(int(round(tval, 1))) + r'\%'
    return str(round(tval, 2)) + r'\%'


def hetero_latex(graphdict, vardict, gvar):
    """Print LaTeX table rows summarising PPP effects by group.

    Every entry of ``graphdict`` is passed to ``mkgraph``; the returned frame's
    ``ATT`` column (rows from position 5 onward) is scaled by the group's PPP
    total taken from the module-level ``df``. One row per ``vardict`` entry is
    printed, followed by a pooled "PPP dollars per ..." footer.

    Args:
        graphdict: mapping whose keys look like ``'<group>_emp'`` /
            ``'<group>_wage'`` and whose values are the specs given to
            ``mkgraph``. Both suffixes must exist for every ``vardict`` key.
        vardict: mapping of group value -> label printed in the first column.
        gvar: name of the grouping column in the module-level ``df``.

    Returns:
        None. All output goes to stdout.
    """
    dflist = {k: mkgraph(vals) for k, vals in graphdict.items()}

    kvars = ['emp_ppp', 'pppamt_calc', 'wage_ppp', 'aaemp_19', kwvar]
    #### get ppp and emp group totals from main file
    dft = df[kvars + [gvar]].to_pandas().groupby(gvar, as_index=False)[kvars].sum()
    dft['tot_ppp'] = dft['pppamt_calc'].sum()
    dft['pct_pppamt'] = 100 * dft['pppamt_calc'] / dft['tot_ppp']
    dft['pct_emp_ppp'] = 100 * dft['emp_ppp'] / dft['aaemp_19']
    dft['pct_wage_ppp'] = 100 * dft['wage_ppp'] / (dft[kwvar])

    resultsdict = {}
    for k, dftb in dflist.items():
        # Keys are '<group>_<emp|wage>'. Numeric group labels are matched as
        # numbers, anything else as the raw string.
        try:
            gval = float(k.split('_')[0])
            dft[gvar] = pd.to_numeric(dft[gvar])
        except (TypeError, ValueError):
            gval = k.split('_')[0]
        varval = k.split('_')[1]
        # Copy the slice so the new column is written to an independent frame
        # instead of a view of the mkgraph result.
        # NOTE(review): the first five rows are dropped -- presumably the
        # pre-treatment event times; confirm against mkgraph's output.
        dftb = dftb.iloc[5:].copy()
        # Totals for this group (one row of the grouped frame).
        grouprow = dft[dft[gvar] == gval]
        dftb['exact'] = ((dftb['ATT']) * grouprow[varval + '_ppp'].values[0]) / 100
        ##give PPP per emp / dollar wage
        kval = grouprow['pppamt_calc'].values[0] / dftb['exact'].sum()
        # pct of PPP money going to group
        kval1 = grouprow['pct_pppamt'].values[0]
        # pct of employment in group receiving PPP
        kval2 = grouprow['pct_emp_ppp'].values[0]
        # total emp/wages retained
        kval3 = dftb['exact'].sum()
        # total PPP
        kval4 = grouprow['pppamt_calc'].values[0]
        resultsdict.update({k: [kval, kval1, kval2, kval3, kval4]})
        print(k)
        print(dftb)

    ###latex table code
    for k, v in vardict.items():
        txt = v + r"&"
        ###$ per emp
        txt = txt + _latex_dollars(resultsdict[str(k) + '_emp'][0]) + "&"
        ##$ per wage
        txt = txt + _latex_dollars(resultsdict[str(k) + '_wage'][0]) + "&"
        ##all other: share of PPP dollars and share of employment with PPP
        for tval in resultsdict[str(k) + '_emp'][1:3]:
            txt = txt + _latex_percent(tval) + "&"
        # Drop the trailing "&" and terminate the LaTeX row.
        print(txt[:-1] + r"\\")
    ####this is for the second set of results that weights an effect for employee months and wages
    print(r"\hline")
    print(r"\hline")
    print(r"\vspace{.25in}\\")
    print(r"&\multicolumn{4}{c}{\$ of PPP Loans per...}\\")
    print(r"&\multicolumn{2}{c}{Employee-Month Retained }&\multicolumn{2}{c}{Dollar-Wage Retained}\\")
    print(r"&\multicolumn{2}{c}{(5)}&\multicolumn{2}{c}{(6)}\\")
    ###constructing aggregation
    ktval = 0
    ktval_wage = 0
    #need this because not every group sums to 100, need to reweight
    tval2sum = 0
    for k in vardict:
        ### $ per emp
        #total emp saved; negative group totals are counted as zero
        tval = resultsdict[str(k) + '_emp'][3]
        #total PPP
        tval2 = resultsdict[str(k) + '_emp'][4]
        if tval < 0:
            tval = 0
        ktval = ktval + tval
        # total wage saved; negative group totals are counted as zero
        tval = resultsdict[str(k) + '_wage'][3]
        if tval < 0:
            tval = 0
        ktval_wage = ktval_wage + tval
        tval2sum = tval2sum + tval2
    # Pooled PPP dollars per employee-month / per wage dollar retained.
    ktval = tval2sum / ktval
    ktval_wage = tval2sum / ktval_wage
    ktval = r"\$" + "{:,}".format(int(round(ktval, 0)))
    ktval_wage = r"\$" + str(round(ktval_wage, 2))
    print(r"&\multicolumn{2}{c}{" + ktval + r"}&\multicolumn{2}{c}{" + ktval_wage + r"}\\")


# y-axis limits placed in the ylim slot of the heterogeneity graph specs below:
# ylim1b for the '_emp' employment specs (ylab1), ylim2b for the closure specs
# (ylab2), ylim3b for the '_wage' specs (ylab3).
ylim1b = (-7,17.5)
ylim2b = (-13.5,3.5)
ylim3b = (-4.5,15)

'''
size estimates

'''
# Bucket establishments into six size groups using the 'ein_aaemp_19' column.
# pd.cut's default right-closed intervals give (..,1], (1,5], (5,10],
# (10,25], (25,100], (100,..).
bins = [-9e9, 1, 5, 10, 25, 100, 9e9]
labs = [1, 2, 3, 4, 5, 6]

df['size_group'] = pd.cut(df['ein_aaemp_19'].to_pandas(), bins=bins, labels=labs)

############ wages
gvar = 'size_group'
# size_group label -> text shown in plot titles and in the LaTeX table.
# NOTE(review): label 6 covers everything above 100; '100-500' overlaps the
# upper edge of '26-100' -- confirm the intended wording.
vardict = {
    1: '1',
    2: '2-5',
    3: '6-10',
    4: '11-25',
    5: '26-100',
    6: '100-500',
}
groupname = 'Firm Size Group'
groupvar = '\nby ' + groupname + ' : '
filesuff1b = filesuff1 + gvar
filesuff2b = filesuff2 + gvar
filesuff3b = filesuff3 + gvar

fname = mostrecentfile('ANldb_pppv7_size.txt', resultsloc1, -1)
# Context manager so the results file is closed once its text is loaded.
with open(fname) as resultsfile:
    txt = resultsfile.read().split('\n')

# Employment specs: size group j (0-based) uses result number j*3 in the file.
graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_emp': [j * 3,
                      'Overall Effects',
                      '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i],
                      ylab1,
                      filesuff1b + str(i),
                      ylim1b,
                      [firstval, lastval],
                      txt] for j, i in enumerate(vardict)
}

numcols = 2
dflist = multigraphplot(graphdict, numcols)

# Wage specs: size group j uses result number j*3 + 1.
graphdict.update({
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_wage': [j * 3 + 1,
                       'Overall Effects',
                       '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i],
                       ylab3,
                       filesuff3b + str(i),
                       ylim3b,
                       [firstval, lastval],
                       txt] for j, i in enumerate(vardict)
})

# hetero_latex looks up both the '_emp' and the '_wage' entry for every
# vardict key, so it is called only after both sets are in graphdict.
hetero_latex(graphdict, vardict, gvar)


####size closures
# Closure specs: size group j (0-based) uses result number j*3 + 2.
# NOTE(review): these specs use the closure label/suffix/limits (ylab2,
# filesuff2b, ylim2b) but still carry the employment title -- likely a
# copy-paste leftover; left unchanged because the intended title text is not
# visible in this part of the file.
graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_emp': [j * 3 + 2,
                      'Overall Effects',
                      '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i],
                      ylab2,
                      filesuff2b + str(i),
                      ylim2b,
                      [firstval, lastval],
                      txt] for j, i in enumerate(vardict)
}

numcols = 2
dflist = multigraphplot(graphdict, numcols)




####size wage
# Wage specs: size group j uses result number j*3 + 1. The title now matches
# the wage label/suffix/limits used here (it previously repeated the
# employment title).
graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_wage': [j * 3 + 1,
                       'Overall Effects',
                       '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i],
                       ylab3,
                       filesuff3b + str(i),
                       ylim3b,
                       [firstval, lastval],
                       txt] for j, i in enumerate(vardict)
}

numcols = 2
dflist = multigraphplot(graphdict, numcols)


'''
ui state estimates

'''

# Tag every row of df with its state's UI replacement-rate bin.


def _strip_name(name):
    """Return the state name without surrounding whitespace."""
    return name.strip()


def _pct_to_float(cell):
    """Convert a cell such as '145%' or '(12%)' to a float, dropping % ( )."""
    return float(cell.translate(str.maketrans('', '', '%()')))


def _fips_for(name):
    """Look up the FIPS code for a state name through stateinfo."""
    return stateinfo(name.strip(), 'fips')


# Load the tab-separated table, drop its last row, and normalise the name of
# the state column.
uipath = dataloc + 'rona/ui_replacement_rates_ganongetal.txt'
dfui = pd.read_csv(uipath, sep='\t')
dfui = dfui.iloc[:-1]
dfui = dfui.rename(columns={'State': 'state'})
dfui['state'] = dfui['state'].apply(_strip_name)

# Every column after the first is parsed from text into numbers.
for c in dfui.columns[1:]:
    dfui[c] = pd.to_numeric(dfui[c].apply(_pct_to_float))

dfui['fipsstate'] = dfui['state'].apply(_fips_for)

# Three bins over 'replacement rate with FPUC': up to 141, 141-154, above 154.
bins = [-1, 141, 154, float('inf')]
wagelist = ['<141%', '141-154%', '154%+']
dfui['replacement_bins'] = pd.cut(
    dfui['replacement rate with FPUC'],
    bins=bins,
    labels=wagelist,
)

# FIPS code -> bin label, then mapped onto the main frame.
tdict = {fips: rbin for fips, rbin in zip(dfui['fipsstate'], dfui['replacement_bins'])}
df['replacement_bins'] = df['fips'].map(tdict)


gvar = 'replacement_bins'
# replacement_bins label -> LaTeX text shown in plot titles and table rows.
# Raw strings: "\%" is not a Python escape sequence; the plain literals
# produced the same text but trigger invalid-escape warnings.
vardict = {
    '<141%': r'$<141\%$',
    '154%+': r'$154\%+$',
}
groupname = 'UI Replacement Rate in State'
groupvar = '\nby ' + groupname + ' : '
filesuff1b = filesuff1 + gvar
filesuff2b = filesuff2 + gvar
filesuff3b = filesuff3 + gvar

fname = mostrecentfile('ANldb_pppv7_uistate.txt', resultsloc1, -1)
# Context manager so the results file is closed once its text is loaded.
with open(fname) as resultsfile:
    txt = resultsfile.read().split('\n')

# Employment specs: group j (0-based, in vardict order) uses result number j*3.
graphdict = {
    #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_emp': [j * 3,
                      'Overall Effects',
                      '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i],
                      ylab1,
                      filesuff1b + str(i),
                      ylim1b,
                      [firstval, lastval],
                      txt] for j, i in enumerate(vardict)
}

# Wage specs: group j uses result number j*3 + 1. The title uses the same
# wording as the other wage plots in this file (it previously read
# "Wage Relative to Pre-Pandemic Employment Baseline").
graphdict.update({
    # marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
    str(i) + '_wage': [j * 3 + 1,
                       'Overall Effects',
                       '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i],
                       ylab3,
                       filesuff3b + str(i),
                       ylim3b,
                       [firstval, lastval],
                       txt] for j, i in enumerate(vardict)
})

numcols = 2
dflist = multigraphplot(graphdict, numcols)



hetero_latex(graphdict, vardict, gvar)





#
#
#
# ############ wages
# gvar = 'avg_wages_bin'
# vardict = {
#     1 : '$<20$k',
# 2 : '20-40k',
# 3 : '40-60k',
# 4 : '60-80k',
# 5 : '$80+$k',
#
# }
# groupname = 'Wage Group'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
# ####this is either 20 or 19
# firstnum = 20
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [firstnum + j - 1, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_heterogen.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [firstnum + j, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_heterogen_wages.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)

'''
poverty
'''
# gvar = 'poverty_cut'
# vardict = {
#     0 : '$<10\%$ Poverty',
# 10 : '10-15\% Poverty',
# 15 : '$>15\%$ Poverty',
# }
#
# firstnum = 0
# groupname = 'Poverty Rates'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
#
#
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [firstnum + j, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_heterogen.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [firstnum + j, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_heterogen_wages.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)



'''
HHI
'''
# gvar = 'hhi_cut'
# vardict = {
#     0 : 'Competitive Marker',
# 1000 : 'Mid-Range Market',
# 2500 : 'Monopsonistic Market',
#
# }
#
# firstnum = 11
# groupname = 'HHI'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
#
#
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [firstnum + j, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_heterogen.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [firstnum + j, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_heterogen_wages.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)



'''
bank distance
'''
# bins = [-1,.5,1,5,9e9]
# distbins = ['0', '1', '2', '3',]
# gvar = 'bank_dist_bins'
# vardict = {
#     0 : 'Less than half a mile',
# 1 : '.5-1 mile',
# 2 : '1-5 miles',
# 3 : '5+ miles',
#
# }
#
# firstnum = 7
# groupname = 'Distance from Nearest Bank'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
#
#
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [firstnum + j, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_heterogen.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [firstnum + j, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_heterogen_wages.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)
#
#


'''
age
'''
#
# gvar = 'age_bins'
# vardict = {
#     0 : '$<6$ Years old',
# 6 : '6-11 years old',
# 11 : '11-21 years old',
# 21 : '$21+$ years old',
#
# }
#
# firstnum = 3
# groupname = 'Age of Establishment'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
#
#
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [firstnum + j, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_heterogen.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [firstnum + j, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_heterogen_wages.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)



'''
getting telework
'''
# #get telework classification
# dfo = opennew('oes/oes_telework_2019', []).rename(columns = {'fipsstate' : 'fips',
#                                                                          'UDBNum' : 'ldb_num'})
# dfo['telework_pct'] = dfo['emp_telework'] / (dfo['emp_telework'] + dfo['emp_nottelework'])
# mc = ['ldb_num']
# for c in mc:
#     dfo[c] = pd.to_numeric(dfo[c].to_pandas())
# df = df.merge(dfo[mc + ['telework_pct']], on = mc, how = 'inner')
#
# bins = [-1,.25,.75,9e9]
# distbins = ['0', '1', '2', ]
#
# df['telework_pct_bins'] = pd.cut(df['telework_pct'].to_pandas(), bins=bins, labels = distbins).astype('str')
#
# gvar = 'telework_pct_bins'
# vardict = {
#     0 : '$<25\%$ of Estab. in Telework Occupations',
# 1 : '25-75\% of Estab. in Telework Occupations',
# 2 : '$75+\%$ of Estab. in Telework Occupations',
#
#
# }
#
#
#
# kvars = ['emp_ppp', 'pppamt_calc', 'wage_ppp', 'aaemp_19', kwvar]
# #### get ppp and emp group totals from main file
# dft = df[kvars + [gvar]].to_pandas().groupby(gvar, as_index=False)[kvars].sum()
# dft['tot_ppp'] = dft['pppamt_calc'].sum()
# dft['pct_pppamt'] = 100 * dft['pppamt_calc'] / dft['tot_ppp']
# dft['pct_emp_ppp'] = 100 * dft['emp_ppp'] / dft['aaemp_19']
# dft['pct_wage_ppp'] = 100 * dft['wage_ppp'] / (dft[kwvar])
#
#
# firstnum = 0
# groupname = 'Telework Category'
# groupvar = '\nby '+ groupname +' : '
# filesuff1b = filesuff1 + gvar
# filesuff3b = filesuff3 + gvar
#
#
#
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_emp': [(firstnum + j)*3+1, 'Overall Effects', '\nEmployment Relative to Pre-Pandemic Employment Baseline' + groupvar + vardict[i], ylab1,  filesuff1b + str(i), ylim1b, mostrecentfile('ANldb_pppv5_soc.txt', resultsloc1, numfile = -1) ] for j,i in enumerate(list(vardict.keys()))
# }
#
#
# graphdict.update({
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     str(i) + '_wage': [(firstnum + j)*3, 'Overall Effects', '\nWages Relative to Pre-Pandemic Wage Baseline' + groupvar + vardict[i], ylab3,  filesuff3b + str(i), ylim3b, mostrecentfile('ANldb_pppv5_soc.txt', resultsloc1, numfile = -1)  ] for j,i in enumerate(list(vardict.keys()))
# })
#
#
# hetero_latex(graphdict, vardict, gvar)


'''
this is for robustness checks
'''

# ylab1 = 'Effect on Employment %\n Relative to 2019 Employment'
# ylab1b = 'Effect on Employment %\n Relative to 2017 Employment'
# ylab2 = 'Effect on Probability of Closing'
# filesuff1b = 'emp_2019'
# filesuff1 = 'emp_removesmall'
# filesuff2 = 'closed_removesmall'
# ylim1 = (-3,8)
# ylim2 = (-5,3)
# graphdict = {
#     #marker : [what number of results in text file, size class value, plttitle, ylab, filesuff, ylim]
#     2 : [4, 'Overall Effects', '\nEmployment Relative to 2017 Employment', ylab1b,  filesuff1b, ylim1 ],
#     1 : [0, 'Overall Effects', '\nEmployment Relative to 2019 Employment', ylab1,  filesuff1, ylim1 ],
# 0 : [1, 'Overall Effects', '\nProbability of Establishment Closure', ylab2,  filesuff2, ylim2 ],
# }
#
# dflist = {}
# for k,vals in graphdict.items():
#     if k>=2:
#         fname = mostrecentfile('ANldb_pppv5_2019.txt', resultsloc1)
#         txt = open(fname).read().split('\n')
#     else:
#         fname = mostrecentfile('ANldb_pppv5_removesmall.txt', resultsloc1)
#         txt = open(fname).read().split('\n')
#     dflist.update({k : mkgraph(txt, vals)})
#
#
# ###latex table code
# dft = pd.concat(list(dflist.values()), axis=1)
# ncols = len(dft.columns)
# dft.columns = [i for i in range(ncols)]
# dft.reset_index(inplace=True)
# txt1 = "\multirow{2}*{"
#
# for i in range(len(dft)):
#     txt = txt1
#     rown = dft.iloc[i]['event_time']
#     txt = txt + str(rown) + "}&ATT&"
#     for j in range(0,ncols,2):
#         tval = str(round(dft.iloc[i][j],2))
#         txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")
#     txt = "&[95\% C.I.]&"
#     for j in range(0,ncols):
#         if j%2==1:
#             tvallb = str(round(dft.iloc[i][j-1] - dft.iloc[i][j],1))
#             tvalub = str(round(dft.iloc[i][j - 1] + dft.iloc[i][j],1))
#             tval = "[" + tvallb + ', '+tvalub + "]"
#             txt = txt + tval + "&"
#     print(txt[:-1] + r"\\")

'''

'''

'''
adding in SOC
'''
# dfo = cudf.read_csv(dataloc + 'oes/oes2digit_2019.csv')
# dfo['ldb_num'] = pd.to_numeric(dfo['UDBNum'].to_pandas())
# soccols = ['emp_occ_' + str(i) for i in range(11,55,2)]
# dfo['tot_emp'] = dfo[soccols].to_pandas().sum(axis=1)
# for i in range(11,55,2):
#     cond = ((dfo['emp_occ_' + str(i)].to_pandas() / dfo['tot_emp'].to_pandas()) > .05)
#     dfo['D5pct_occ_' + str(i)] = np.where(cond, 1, 0)
#
# kcols = ['ldb_num'] + ['D5pct_occ_' + str(i) for i in range(11,55,2)]
# dfo = dfo[kcols]
#
# mc = ['ldb_num']
# df = df.merge(dfo, on = mc, how = 'inner')
#
# #get telework classification
# dfo = opennew('oes/oes_telework_2019', []).rename(columns = {'fipsstate' : 'fips',
#                                                                          'UDBNum' : 'ldb_num'})
# dfo['telework_pct'] = dfo['emp_telework'] / (dfo['emp_telework'] + dfo['emp_nottelework'])
# mc = ['ldb_num']
# for c in mc:
#     dfo[c] = pd.to_numeric(dfo[c].to_pandas())
# df = df.merge(dfo[mc + ['telework_pct']], on = mc, how = 'inner')
#
# bins = [-1,.25,.75,9e9]
# distbins = ['0', '1', '2', ]
#
# df['telework_pct_bins'] = pd.cut(df['telework_pct'].to_pandas(), bins=bins, labels = distbins).astype('str')
#
# df.to_pandas().to_csv(dataloc + 'pppfiles/' + filename + '_OES.csv')
#

